package upenn.junto.algorithm.parallel;

/**
 * Copyright 2011 Partha Pratim Talukdar
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import gnu.trove.TObjectDoubleHashMap;
import gnu.trove.TObjectDoubleIterator;

import java.io.IOException;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.hadoop.mapred.TextOutputFormat;
import org.apache.hadoop.mapred.jobcontrol.Job;

import upenn.junto.util.*;
import upenn.junto.config.*;

public class MADHadoop {
	
   // Field separator shared by the mapper and the reducer: it is used to split
   // input vertex lines and to build and parse the messages exchanged between them.
   private static String _kDelim = "\t";
	
   /**
    * Mapper for a single MAD iteration.
    *
    * Each input value is one vertex line with the delimiter-separated fields
    *   id gold_label injected_labels estimated_labels neighbors rw_probabilities
    *
    * For every vertex the mapper emits:
    *   -- one self message, keyed by the vertex id, carrying the complete
    *        original line prefixed with "labels";
    *   -- for every neighbor, keyed by the neighbor id, a "labels" message with
    *        this vertex's current label scores and an "edge_info" message with
    *        the edge weight to that neighbor and this vertex's continuation
    *        probability.
    */
   public static class MADHadoopMap extends MapReduceBase
       implements Mapper<LongWritable, Text, Text, Text> {

     // Number of fields a vertex line must have (see the class comment).
     private static final int NUM_VERTEX_FIELDS = 6;

     /**
      * Parses one vertex line and emits the self message and the neighbor messages.
      *
      * @param key      offset key supplied by the input format (not used)
      * @param value    vertex line
      * @param output   collector receiving (target_vertex_id, message) pairs
      * @param reporter progress reporter (not used)
      * @throws IOException if the line has fewer than the expected number of
      *                     fields, or if the collector fails
      */
     public void map(LongWritable key, Text value,
                     OutputCollector<Text, Text> output,
                     Reporter reporter) throws IOException {
       String line = value.toString();

       // id gold_label injected_labels estimated_labels neighbors rw_probabilities
       String[] fields = line.split(_kDelim);
       if (fields.length < NUM_VERTEX_FIELDS) {
         // Fail with the offending line rather than an anonymous
         // ArrayIndexOutOfBoundsException further down.
         throw new IOException("Malformed vertex line: expected at least " +
                               NUM_VERTEX_FIELDS + " fields but found " +
                               fields.length + ": " + line);
       }

       String vertexId = fields[0];
       TObjectDoubleHashMap neighbors = CollectionUtil.String2Map(fields[4]);
       TObjectDoubleHashMap rwProbabilities = CollectionUtil.String2Map(fields[5]);

       // If the current node is a seed node (it has injected labels) but has no
       // estimated labels yet, advertise the injected labels to the neighbors
       // instead. This is expected to matter mostly in the very first iteration.
       // The self message below still carries the unmodified line.
       String labelsToSend = fields[3];
       if (labelsToSend.length() == 0 && fields[2].length() > 0) {
         labelsToSend = fields[2];
       }

       // TODO(partha): move messages to ProtocolBuffers

       // message to self: stores the injection labels, neighbors and
       // random walk probabilities for the reducer.
       output.collect(new Text(vertexId), new Text("labels" + _kDelim + line));

       // Everything below that does not depend on the neighbor is built once.
       // The "labels" value is identical for all neighbors and is never
       // mutated, so the same Text instance is handed to every collect call.
       Text labelsMessage =
           new Text("labels" + _kDelim + vertexId + _kDelim + labelsToSend);
       String edgeInfoPrefix = "edge_info" + _kDelim + vertexId + _kDelim;
       String edgeInfoSuffix = _kDelim + rwProbabilities.get(Constants._kContProb);

       // messages to neighbors
       TObjectDoubleIterator neighIterator = neighbors.iterator();
       while (neighIterator.hasNext()) {
         neighIterator.advance();
         Text neighborId = new Text((String) neighIterator.key());

         // message (neighbor_node, "labels" + DELIM + curr_node + DELIM + curr_node_label_scores)
         output.collect(neighborId, labelsMessage);

         // message (neighbor_node, "edge_info" + DELIM + curr_node + DELIM +
         //          edge_weight_to_neighbor + DELIM + curr_node_cont_prob)
         // The iterator already exposes the edge weight, so no second lookup
         // into the neighbors map is needed.
         output.collect(neighborId,
                        new Text(edgeInfoPrefix + neighIterator.value() + edgeInfoSuffix));
       }
     }
   }
	 	
   /**
    * Reducer for a single MAD iteration.
    *
    * For every vertex it receives the messages produced by the mapper (one self
    * message plus "labels" and "edge_info" messages from neighbors), computes the
    * new estimated label scores of the vertex and writes the vertex back in the
    * same line layout the mapper reads, so the output can feed the next iteration.
    *
    * The hyperparameters are read from the job configuration in
    * {@link #configure(JobConf)} and kept per reducer instance.
    */
   public static class MADHadoopReduce extends MapReduceBase implements Reducer<Text, Text, Text, Text> {
     // Hyperparameters, populated by configure(). They are instance state so
     // that a reducer only ever sees the values of the job that configured it.
     private double mu1;
     private double mu2;
     private double mu3;
     private int keepTopKLabels;

     /**
      * Reads "mu1", "mu2", "mu3" and "keepTopKLabels" from the job configuration.
      *
      * @param conf job configuration; all four keys must be present and parseable
      */
     public void configure(JobConf conf) {
       mu1 = Double.parseDouble(conf.get("mu1"));
       mu2 = Double.parseDouble(conf.get("mu2"));
       mu3 = Double.parseDouble(conf.get("mu3"));
       keepTopKLabels = Integer.parseInt(conf.get("keepTopKLabels"));
     }

     /**
      * Combines all messages addressed to one vertex into its new label scores.
      *
      * @param key      id of the vertex being updated
      * @param values   messages sent to this vertex by the mappers
      * @param output   collector receiving (vertex_id, rebuilt vertex fields)
      * @param reporter progress reporter (not used)
      * @throws IOException if the collector fails
      */
     public void reduce(Text key, Iterator<Text> values,
                        OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
       String vertexId = key.toString();

       // new scores estimated for the current node
       TObjectDoubleHashMap newEstimatedScores = new TObjectDoubleHashMap();

       // set to true only if the message the node sent to itself is found.
       boolean isSelfMessageFound = false;
       String vertexString = "";

       // filled in from the self message
       TObjectDoubleHashMap neighbors = null;
       TObjectDoubleHashMap randWalkProbs = null;

       // neighbor id -> serialized label scores of that neighbor
       HashMap<String, String> neighScores = new HashMap<String, String>();

       // neighbor id -> weight of the edge from that neighbor / its continuation probability
       TObjectDoubleHashMap incomingEdgeWeights = new TObjectDoubleHashMap();
       TObjectDoubleHashMap neighborContProb = new TObjectDoubleHashMap();

       // iterate over all the messages received at the node
       while (values.hasNext()) {
         String val = values.next().toString();
         String[] fields = val.split(_kDelim);

         // first field represents the type of message
         String msgType = fields[0];

         if (msgType.equals("labels")) {
           // self-message check
           // NOTE(review): a neighbor "labels" message whose sender id equals
           // this vertex id (i.e. a vertex listing itself as a neighbor) would
           // also take this branch -- confirm self-loops cannot occur in the input.
           if (vertexId.equals(fields[1])) {
             isSelfMessageFound = true;
             vertexString = val;

             // self message layout:
             //   labels id gold_label injected_labels estimated_labels neighbors rw_probabilities
             TObjectDoubleHashMap injLabels = CollectionUtil.String2Map(fields[3]);
             neighbors = CollectionUtil.String2Map(neighbors, fields[5]);
             randWalkProbs = CollectionUtil.String2Map(fields[6]);

             if (injLabels.size() > 0) {
               // add injected labels to the estimated scores.
               ProbUtil.AddScores(newEstimatedScores,
                                  mu1 * randWalkProbs.get(Constants._kInjProb),
                                  injLabels);
             }
           } else {
             // a missing third field represents that the
             // neighbor has no valid label assignment yet.
             if (fields.length > 2) {
               neighScores.put(fields[1], fields[2]);
             }
           }
         } else if (msgType.equals("edge_info")) {
           // edge_info neigh_vertex incoming_edge_weight cont_prob
           String neighId = fields[1];

           // only the first value seen for a neighbor is kept
           if (!incomingEdgeWeights.contains(neighId)) {
             incomingEdgeWeights.put(neighId, Double.parseDouble(fields[2]));
           }

           if (!neighborContProb.contains(neighId)) {
             neighborContProb.put(neighId, Double.parseDouble(fields[3]));
           }
         } else {
           MessagePrinter.PrintAndDie("Invalid message: " + val);
         }
       }

       // terminate if message from self is not received.
       if (!isSelfMessageFound) {
         MessagePrinter.PrintAndDie("Self message not received for node " + vertexId);
       }

       // collect neighbors' label distributions and create one single
       // label distribution
       TObjectDoubleHashMap weightedNeigLablDist = new TObjectDoubleHashMap();
       for (String neighName : neighScores.keySet()) {
         // p_{i}^{cont} W_{ij} + p_{j}^{cont} W_{ji}
         double mult = randWalkProbs.get(Constants._kContProb) * neighbors.get(neighName) +
                       neighborContProb.get(neighName) * incomingEdgeWeights.get(neighName);

         ProbUtil.AddScores(weightedNeigLablDist,
                            mu2 * mult,
                            CollectionUtil.String2Map(neighScores.get(neighName)));
       }

       // now add the collective neighbor label distribution to
       // the estimate of the current node's labels.
       ProbUtil.AddScores(newEstimatedScores, 1.0, weightedNeigLablDist);

       // add dummy label scores
       ProbUtil.AddScores(newEstimatedScores,
                          mu3 * randWalkProbs.get(Constants._kTermProb),
                          Constants.GetDummyLabelDist());

       // Integer.MAX_VALUE means "keep every label"
       if (keepTopKLabels < Integer.MAX_VALUE) {
         ProbUtil.KeepTopScoringKeys(newEstimatedScores, keepTopKLabels);
       }

       ProbUtil.DivScores(newEstimatedScores,
                          GetNormalizationConstant(neighbors, randWalkProbs,
                                                   incomingEdgeWeights, neighborContProb,
                                                   mu1, mu2, mu3));

       // now reconstruct the vertex representation (with the new estimated scores)
       // so that the output of this reducer can be used as input of the next
       // iteration's mapper.
       String[] vertexFields = vertexString.split(_kDelim);

       // Skip the first two fields as they contain the message header and the
       // vertex id respectively (the id is emitted as the output key).
       String[] newVertexFields = new String[vertexFields.length - 2];
       System.arraycopy(vertexFields, 2, newVertexFields, 0, newVertexFields.length);

       // remaining layout: gold_label injected_labels estimated_labels ...;
       // replace the estimated scores with the new ones.
       newVertexFields[2] = CollectionUtil.Map2String(newEstimatedScores);

       output.collect(key, new Text(CollectionUtil.Join(newVertexFields, _kDelim)));
     }

     /**
      * Computes the constant the new scores of a vertex are divided by:
      *
      *   mu1 x p^{inj} + mu2 x \sum_j (p_{i}^{cont} W_{ij} + p_{j}^{cont} W_{ji}) + mu3
      *
      * where j ranges over the entries of {@code neighbors}.
      *
      * @param neighbors           neighbor id -> weight of the edge to that neighbor
      * @param randWalkProbs       random walk probabilities of the current vertex
      * @param incomingEdgeWeights neighbor id -> weight of the edge from that neighbor
      * @param neighborContProb    neighbor id -> continuation probability of that neighbor
      * @param mu1                 weight of the injection term
      * @param mu2                 weight of the neighborhood term
      * @param mu3                 weight of the dummy-label term
      * @return the normalization constant
      */
     public double GetNormalizationConstant(
                       TObjectDoubleHashMap neighbors,
                       TObjectDoubleHashMap randWalkProbs,
                       TObjectDoubleHashMap incomingEdgeWeights,
                       TObjectDoubleHashMap neighborContProb,
                       double mu1, double mu2, double mu3) {
       double totalNeighWeight = 0;
       TObjectDoubleIterator nIter = neighbors.iterator();
       while (nIter.hasNext()) {
         nIter.advance();

         // p_{i}^{cont} W_{ij}
         totalNeighWeight +=
             randWalkProbs.get(Constants._kContProb) * nIter.value();

         // p_{j}^{cont} W_{ji}
         String neighName = (String) nIter.key();
         totalNeighWeight += neighborContProb.get(neighName) *
                             incomingEdgeWeights.get(neighName);
       }

       // The neighborhood term is used at full weight here; an earlier variant
       // of the formula scaled it by 0.5.
       return mu1 * randWalkProbs.get(Constants._kInjProb) +
              mu2 * totalNeighWeight +
              mu3;
     }
   }
 	
   /**
    * Driver: runs the requested number of MAD iterations, submitting one Hadoop
    * job per iteration and waiting for it before starting the next.
    *
    * Configuration is obtained through ConfigReader.read_config(args). Required
    * keys: "hdfs_input_pattern", "hdfs_output_base", "iters", "mu1", "mu2", "mu3".
    * Optional keys: "num_reducers" (default 10) and "keep_top_k_labels"
    * (default Integer.MAX_VALUE).
    *
    * Iteration i writes to "<hdfs_output_base>_iter_<i>"; the first iteration
    * reads "hdfs_input_pattern" and every later one reads all files under the
    * previous iteration's output directory.
    *
    * @param args arguments handed to the configuration reader
    * @throws Exception if configuring or running a job fails
    */
   public static void main(String[] args) throws Exception {
     Hashtable config = ConfigReader.read_config(args);

     String initialInputPattern = Defaults.GetValueOrDie(config, "hdfs_input_pattern");
     String outputBase = Defaults.GetValueOrDie(config, "hdfs_output_base");
     int numIterations = Integer.parseInt(Defaults.GetValueOrDie(config, "iters"));
     int numReducers = Defaults.GetValueOrDefault((String) config.get("num_reducers"), 10);

     // output directory of the most recently completed iteration
     String previousOutputDir = "";

     for (int iter = 1; iter <= numIterations; ++iter) {
       JobConf conf = new JobConf(MADHadoop.class);
       conf.setJobName("mad_hadoop");

       // key/value types and formats
       conf.setOutputKeyClass(Text.class);
       conf.setOutputValueClass(Text.class);
       conf.setInputFormat(TextInputFormat.class);
       conf.setOutputFormat(TextOutputFormat.class);

       // map/reduce implementation
       conf.setMapperClass(MADHadoopMap.class);
       conf.setReducerClass(MADHadoopReduce.class);
       conf.setNumReduceTasks(numReducers);

       // hyperparameters, handed to the reducer through the job configuration
       conf.set("mu1", Defaults.GetValueOrDie(config, "mu1"));
       conf.set("mu2", Defaults.GetValueOrDie(config, "mu2"));
       conf.set("mu3", Defaults.GetValueOrDie(config, "mu3"));
       String defaultTopK = Integer.toString(Integer.MAX_VALUE);
       conf.set("keepTopKLabels",
                Defaults.GetValueOrDefault((String) config.get("keep_top_k_labels"),
                                           defaultTopK));

       // the first iteration reads the original input; every later one reads
       // the output of the iteration before it
       String inputPattern = (iter == 1) ? initialInputPattern
                                         : previousOutputDir + "/*";
       FileInputFormat.setInputPaths(conf, new Path(inputPattern));

       String outputDir = outputBase + "_iter_" + iter;
       FileOutputFormat.setOutputPath(conf, new Path(outputDir));

       JobClient.runJob(conf);
       previousOutputDir = outputDir;
     }
   }
}
